import importlib.util
import inspect
import json
import multiprocessing as mp
import os
import time
from functools import partial

import pytest
from pydantic import BaseModel

from letta.functions.functions import derive_openai_json_schema
from letta.functions.schema_generator import validate_google_style_docstring
from letta.llm_api.helpers import convert_to_structured_output, make_post_request
from letta.schemas.tool import Tool, ToolCreate


def _clean_diff(d1, d2):
    """Summarize the top-level differences between two dictionaries.

    Args:
        d1: The "before" dictionary.
        d2: The "after" dictionary.

    Returns:
        A dict with up to three entries, each present only when non-empty:
        ``"removed"`` (keys only in ``d1``, with their values), ``"added"``
        (keys only in ``d2``, with their values) and ``"changed"`` (shared
        keys whose values differ, mapped to a ``(d1_value, d2_value)`` tuple).
    """
    left_keys = d1.keys()
    right_keys = d2.keys()

    sections = (
        ("removed", {key: d1[key] for key in left_keys - right_keys}),
        ("added", {key: d2[key] for key in right_keys - left_keys}),
        ("changed", {key: (d1[key], d2[key]) for key in left_keys & right_keys if d1[key] != d2[key]}),
    )

    # Drop empty sections so an identical pair yields an empty report
    report = {}
    for label, entries in sections:
        if entries:
            report[label] = entries
    return report


def _compare_schemas(generated_schema: dict, expected_schema: dict, strip_heartbeat: bool = True):
    """Compare an autogenerated schema to an expected schema.

    Args:
        generated_schema: The schema produced by the generator. It is never
            modified by this function.
        expected_schema: The schema the generator is expected to produce.
        strip_heartbeat: When true, the ``request_heartbeat`` entry is removed
            from the generated schema's ``parameters.properties`` and
            ``parameters.required`` before comparing.

    Raises:
        AssertionError: If the two schemas are not equal. Both schemas and
            their top-level diff are printed first.
        KeyError: If ``strip_heartbeat`` is true and the generated schema has
            no ``request_heartbeat`` property.
        ValueError: If ``strip_heartbeat`` is true and ``request_heartbeat``
            is not in the generated schema's required list.
    """

    if strip_heartbeat:
        import copy

        # Strip from a private copy: mutating the caller's dict would make a
        # second comparison of the same schema fail with a KeyError.
        generated_schema = copy.deepcopy(generated_schema)
        parameters = generated_schema["parameters"]
        # Pop out the heartbeat parameter
        del parameters["properties"]["request_heartbeat"]
        # Remove from the required list
        parameters["required"].remove("request_heartbeat")

    if generated_schema == expected_schema:
        print("Schemas are equal")
        return

    # Pretty print both sides and the difference to make failures debuggable
    print("==== GENERATED SCHEMA ====")
    print(json.dumps(generated_schema, indent=4))
    print("==== EXPECTED SCHEMA ====")
    print(json.dumps(expected_schema, indent=4))
    print("==== DIFF ====")
    print(json.dumps(_clean_diff(generated_schema, expected_schema), indent=4))
    raise AssertionError("Schemas are not equal")


def _run_schema_test(schema_name: str, desired_function_name: str, expect_structured_output_fail: bool = False):
    """Derive a schema from a fixture source file and check it against the stored expectations.

    Args:
        schema_name: Base name of the fixture files inside
            ``test_tool_schema_parsing_files`` (``<name>.py``, ``<name>.json``
            and, unless a failure is expected, ``<name>_so.json``).
        desired_function_name: Name passed to ``derive_openai_json_schema``.
        expect_structured_output_fail: When true, converting the schema to
            structured output must raise ``ValueError``; otherwise the
            conversion result is compared to ``<name>_so.json``.

    Returns:
        The tuple ``(schema_name, True)`` when every check passed.
    """
    # Resolve fixtures relative to this module so the test can be launched from any directory
    here = os.path.dirname(__file__)

    def _fixture_path(suffix):
        return os.path.join(here, f"test_tool_schema_parsing_files/{schema_name}{suffix}")

    with open(_fixture_path(".py"), "r") as source_file:
        source_code = source_file.read()
    schema = derive_openai_json_schema(source_code, name=desired_function_name)

    with open(_fixture_path(".json"), "r") as expected_file:
        expected_schema = json.load(expected_file)
    _compare_schemas(schema, expected_schema, False)

    if expect_structured_output_fail:
        # The conversion itself is the assertion here; there is nothing to compare afterwards
        with pytest.raises(ValueError):
            convert_to_structured_output(schema)
        return (schema_name, True)

    structured_output = convert_to_structured_output(schema)
    with open(_fixture_path("_so.json"), "r") as expected_so_file:
        expected_structured_output = json.load(expected_so_file)
    _compare_schemas(structured_output, expected_structured_output, strip_heartbeat=False)

    return (schema_name, True)


def test_derive_openai_json_schema():
    """Test that the schema generator works across a variety of example source code inputs.

    Each case is ``(fixture name, function name, expect structured-output
    failure)`` and is executed by ``_run_schema_test`` in a worker process.
    """

    test_cases = [
        ("pydantic_as_single_arg_example", "create_step", False),
        ("list_of_pydantic_example", "create_task_plan", False),
        ("nested_pydantic_as_arg_example", "create_task_plan", False),
        ("simple_d20", "roll_d20", False),
        ("all_python_complex", "check_order_status", True),
        ("all_python_complex_nodict", "check_order_status", False),
    ]

    # The context manager terminates the workers on exit, so a failing or
    # timed-out case no longer leaves the pool (and its processes) behind.
    with mp.Pool(processes=min(mp.cpu_count(), len(test_cases))) as pool:
        # Submit everything first so the cases run in parallel
        pending = []
        for schema_name, function_name, expect_fail in test_cases:
            print(f"==== TESTING {schema_name} ====")
            async_result = pool.apply_async(_run_schema_test, args=(schema_name, function_name, expect_fail))
            pending.append((schema_name, async_result))

        # Collect results; .get() re-raises any exception raised in the worker
        for schema_name, async_result in pending:
            try:
                _, success = async_result.get(timeout=60)
                assert success, f"Test for {schema_name} failed"
                print(f"Test for {schema_name} passed")
            except Exception as e:
                print(f"Test for {schema_name} failed with error: {str(e)}")
                raise

        # Graceful shutdown on the success path
        pool.close()
        pool.join()


def _openai_payload(test_config):
    """Derive a tool schema from a fixture file and POST it to OpenAI as a tool definition.

    For the ``all_python_complex`` fixture with structured output enabled, no
    request is sent: the conversion to structured output is expected to raise
    ``ValueError``, and that outcome is reported as success.

    Args:
        test_config: A tuple containing (filename, model, structured_output)

    Returns:
        A tuple of (filename, model, structured_output, success, error_message).
        ``error_message`` is ``None`` on success.
    """
    filename, model, structured_output = test_config

    try:
        fixture_path = os.path.join(os.path.dirname(__file__), f"test_tool_schema_parsing_files/{filename}.py")
        with open(fixture_path, "r") as fixture:
            source_code = fixture.read()

        schema = derive_openai_json_schema(source_code)

        # This fixture is expected to be rejected by the structured-output conversion
        if filename == "all_python_complex" and structured_output:
            try:
                convert_to_structured_output(schema)
            except ValueError:
                return (filename, model, structured_output, True, None)
            return (
                filename,
                model,
                structured_output,
                False,
                "Expected ValueError for all_python_complex with structured_output=True",
            )

        tool_schema = convert_to_structured_output(schema) if structured_output else schema

        api_key = os.getenv("OPENAI_API_KEY")
        assert api_key is not None, "OPENAI_API_KEY must be set"

        # Simple system prompt to encourage the LLM to jump directly to a tool call
        system_prompt = "You job is to test the tool that you've been provided. Don't ask for any clarification on the args, just come up with some dummy data and try executing the tool."

        request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
        request_body = {
            "model": model,
            "messages": [{"role": "system", "content": system_prompt}],
            "tools": [{"type": "function", "function": tool_schema}],
            "tool_choice": "auto",
            "parallel_tool_calls": False,
        }

        make_post_request("https://api.openai.com/v1/chat/completions", request_headers, request_body)
    except Exception as exc:
        # Report the failure as data so it survives the trip back from the worker process
        return (filename, model, structured_output, False, str(exc))

    return (filename, model, structured_output, True, None)


@pytest.mark.parametrize("openai_model", ["gpt-4o"])
@pytest.mark.parametrize("structured_output", [True, False])
def test_valid_schemas_via_openai(openai_model: str, structured_output: bool):
    """Send every fixture schema to OpenAI via ``_openai_payload`` and fail on any error.

    ``_openai_payload`` already reports the expected ``ValueError`` for
    ``all_python_complex`` with structured output as a success, so any result
    with ``success == False`` is a genuine failure, including unexpected
    errors for that fixture, which must not be silently ignored.

    Args:
        openai_model: Model name placed in the request payload.
        structured_output: Whether schemas are converted to structured output
            before being sent.
    """

    start_time = time.time()

    filenames = [
        "pydantic_as_single_arg_example",
        "list_of_pydantic_example",
        "nested_pydantic_as_arg_example",
        "simple_d20",
        "all_python_complex",
        "all_python_complex_nodict",
    ]
    test_configs = [(filename, openai_model, structured_output) for filename in filenames]

    # Run the cases in parallel. The context manager guarantees the workers
    # are cleaned up even if mapping raises.
    with mp.Pool(processes=min(mp.cpu_count(), len(test_configs))) as pool:
        results = pool.map(_openai_payload, test_configs)
        pool.close()
        pool.join()

    # The pool is already shut down here, so failing the test cannot leak it
    for filename, model, structured, success, error_message in results:
        print(f"Test for {filename}, {model}, structured_output={structured}: {'SUCCESS' if success else 'FAILED'}")

        if not success:
            pytest.fail(f"Failed for {filename} with {model}, structured_output={structured}: {error_message}")

    end_time = time.time()
    print(f"Total execution time: {end_time - start_time:.2f} seconds")


# Parallel implementation for Composio test
def _run_composio_test(action_name, openai_model, structured_output):
    """Build a tool from a Composio action and POST its schema to OpenAI as a tool definition.

    Args:
        action_name: Composio action name passed to ``ToolCreate.from_composio``.
        openai_model: Model name placed in the request payload.
        structured_output: When truthy, the schema is converted with
            ``convert_to_structured_output`` before being sent.

    Returns:
        ``(action_name, True, None)`` on success, or
        ``(action_name, False, <error text>)`` if any step raised.
    """
    try:
        composio_tool = ToolCreate.from_composio(action_name=action_name)
        assert composio_tool.json_schema
        schema = composio_tool.json_schema

        tool_schema = convert_to_structured_output(schema) if structured_output else schema

        api_key = os.getenv("OPENAI_API_KEY")
        assert api_key is not None, "OPENAI_API_KEY must be set"

        system_prompt = "You job is to test the tool that you've been provided. Don't ask for any clarification on the args, just come up with some dummy data and try executing the tool."

        request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
        request_body = {
            "model": openai_model,
            "messages": [{"role": "system", "content": system_prompt}],
            "tools": [{"type": "function", "function": tool_schema}],
            "tool_choice": "auto",
            "parallel_tool_calls": False,
        }

        make_post_request("https://api.openai.com/v1/chat/completions", request_headers, request_body)
    except Exception as exc:
        # Failures are returned as data so they can cross the process boundary
        return (action_name, False, str(exc))

    return (action_name, True, None)


@pytest.mark.parametrize("openai_model", ["gpt-4o-mini"])
@pytest.mark.parametrize("structured_output", [True])
def test_composio_tool_schema_generation(openai_model: str, structured_output: bool):
    """Test that we can generate the schemas for some Composio tools.

    Skipped when ``COMPOSIO_API_KEY`` is not set. Each action is handled by
    ``_run_composio_test`` in a worker process.

    Args:
        openai_model: Model name placed in the request payload.
        structured_output: Whether schemas are converted to structured output
            before being sent.
    """

    if not os.getenv("COMPOSIO_API_KEY"):
        pytest.skip("COMPOSIO_API_KEY not set")

    start_time = time.time()

    action_names = [
        "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER",  # Simple
        "CAL_GET_AVAILABLE_SLOTS_INFO",  # has an array arg, needs to be converted properly
        "SALESFORCE_RETRIEVE_LEAD_BY_ID",  # has an array arg, needs to be converted properly
        "FIRECRAWL_SEARCH",  # has an optional array arg, needs to be converted properly
    ]

    func = partial(_run_composio_test, openai_model=openai_model, structured_output=structured_output)

    # Shut the pool down before asserting: previously a failing assertion
    # skipped close()/join() and leaked the worker processes.
    with mp.Pool(processes=min(mp.cpu_count(), len(action_names))) as pool:
        results = pool.map(func, action_names)
        pool.close()
        pool.join()

    # Check results
    for action_name, success, error_message in results:
        print(f"Test for {action_name}: {'SUCCESS' if success else 'FAILED - ' + error_message}")
        assert success, f"Test for {action_name} failed: {error_message}"

    end_time = time.time()
    print(f"Total execution time: {end_time - start_time:.2f} seconds")


@pytest.mark.parametrize("openai_model", ["gpt-4o-mini"])
@pytest.mark.parametrize("structured_output", [True])
def test_langchain_tool_schema_generation(openai_model: str, structured_output: bool):
    """Generate a schema from a Langchain Wikipedia tool and POST it to OpenAI as a tool definition.

    Args:
        openai_model: Model name placed in the request payload.
        structured_output: Whether the schema is converted to structured
            output before being sent.
    """
    from langchain_community.tools import WikipediaQueryRun
    from langchain_community.utilities import WikipediaAPIWrapper

    langchain_tool = WikipediaQueryRun(
        api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500),
    )

    tool_create = ToolCreate.from_langchain(
        langchain_tool=langchain_tool,
        additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"},
    )

    assert tool_create.json_schema
    schema = tool_create.json_schema
    print(f"The schema for {langchain_tool.name}: {json.dumps(schema, indent=4)}\n\n")

    try:
        tool_schema = convert_to_structured_output(schema) if structured_output else schema

        api_key = os.getenv("OPENAI_API_KEY")
        assert api_key is not None, "OPENAI_API_KEY must be set"

        system_prompt = "You job is to test the tool that you've been provided. Don't ask for any clarification on the args, just come up with some dummy data and try executing the tool."

        request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
        request_body = {
            "model": openai_model,
            "messages": [{"role": "system", "content": system_prompt}],
            "tools": [{"type": "function", "function": tool_schema}],
            "tool_choice": "auto",
            "parallel_tool_calls": False,
        }

        make_post_request("https://api.openai.com/v1/chat/completions", request_headers, request_body)
        print(f"Successfully called OpenAI using schema generated from {langchain_tool.name}\n\n")
    except Exception:
        # Log which tool failed, then let pytest report the original error
        print(f"Failed to call OpenAI using schema generated from {langchain_tool.name}\n\n")
        raise


# Helper function for pydantic args schema test
def _run_pydantic_args_test(filename, openai_model, structured_output):
    """Build a ``Tool`` from a fixture module and POST its schema to OpenAI as a tool definition.

    The fixture module is imported from ``test_tool_schema_parsing_files``.
    The tool uses the last function defined in that module (in
    ``inspect.getmembers`` order) and, if the module defines one, the JSON
    schema of its last ``BaseModel`` subclass as ``args_json_schema``.

    For ``all_python_complex`` with structured output enabled no request is
    sent: the conversion is expected to raise ``ValueError``.

    Args:
        filename: Base name (without ``.py``) of the fixture module.
        openai_model: Model name placed in the request payload.
        structured_output: When truthy, the schema is converted with
            ``convert_to_structured_output`` before being sent.

    Returns:
        ``(filename, True, None)`` on success, or
        ``(filename, False, <error text>)`` on failure.
    """
    try:
        # Import the fixture module directly from its file path
        module_path = os.path.join(os.path.dirname(__file__), f"test_tool_schema_parsing_files/{filename}.py")
        module_spec = importlib.util.spec_from_file_location(filename, module_path)
        fixture_module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(fixture_module)

        # Keep the last function and the last pydantic model defined in the module itself
        function_name = None
        function_source = None
        model_class = None
        for member_name, member in inspect.getmembers(fixture_module):
            if inspect.isfunction(member):
                if member.__module__ == fixture_module.__name__:
                    function_name = member_name
                    # Only the function's own source is used, not the whole file
                    function_source = inspect.getsource(member)
            elif inspect.isclass(member) and member.__module__ == fixture_module.__name__ and issubclass(member, BaseModel):
                model_class = member

        args_schema = model_class.model_json_schema() if model_class else None

        schema = Tool(
            name=function_name,
            source_code=function_source,
            args_json_schema=args_schema,
        ).json_schema

        # We expect the conversion to fail for all_python_complex with structured_output=True
        if filename == "all_python_complex" and structured_output:
            try:
                convert_to_structured_output(schema)
            except ValueError:
                return (filename, True, None)
            return (filename, False, "Expected ValueError but conversion succeeded")

        tool_schema = convert_to_structured_output(schema) if structured_output else schema

        api_key = os.getenv("OPENAI_API_KEY")
        assert api_key is not None, "OPENAI_API_KEY must be set"

        system_prompt = "You job is to test the tool that you've been provided. Don't ask for any clarification on the args, just come up with some dummy data and try executing the tool."

        request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
        request_body = {
            "model": openai_model,
            "messages": [{"role": "system", "content": system_prompt}],
            "tools": [{"type": "function", "function": tool_schema}],
            "tool_choice": "auto",
            "parallel_tool_calls": False,
        }

        make_post_request("https://api.openai.com/v1/chat/completions", request_headers, request_body)
    except Exception as exc:
        # Failures are returned as data so they can cross the process boundary
        return (filename, False, str(exc))

    return (filename, True, None)


@pytest.mark.parametrize("openai_model", ["gpt-4o"])
@pytest.mark.parametrize("structured_output", [True, False])
def test_valid_schemas_with_pydantic_args_schema(openai_model: str, structured_output: bool):
    """Run ``_run_pydantic_args_test`` for every fixture in parallel and assert each one succeeded.

    Args:
        openai_model: Model name placed in the request payload.
        structured_output: Whether schemas are converted to structured output
            before being sent.
    """

    start_time = time.time()

    filenames = [
        "pydantic_as_single_arg_example",
        "list_of_pydantic_example",
        "nested_pydantic_as_arg_example",
        "simple_d20",
        "all_python_complex",
        "all_python_complex_nodict",
    ]

    func = partial(_run_pydantic_args_test, openai_model=openai_model, structured_output=structured_output)

    # Shut the pool down before asserting: previously a failing assertion
    # skipped close()/join() and leaked the worker processes.
    with mp.Pool(processes=min(mp.cpu_count(), len(filenames))) as pool:
        results = pool.map(func, filenames)
        pool.close()
        pool.join()

    # Check results
    for filename, success, error_message in results:
        print(f"Test for {filename}: {'SUCCESS' if success else 'FAILED - ' + error_message}")

        # The helper reports the expected conversion failure as success, so
        # both branches require success and differ only in the message.
        if filename == "all_python_complex" and structured_output:
            assert success, f"Expected failure handling for {filename} didn't work: {error_message}"
        else:
            assert success, f"Test for {filename} failed: {error_message}"

    end_time = time.time()
    print(f"Total execution time: {end_time - start_time:.2f} seconds")


# Google-style docstring validation tests


# ---------- helpers ----------
def _check(fn, expected_regex: str | None = None):
    """Run ``validate_google_style_docstring`` on ``fn`` and check the outcome.

    Args:
        fn: The callable whose docstring is validated.
        expected_regex: ``None`` when validation must succeed; otherwise a
            pattern that the raised ``ValueError`` must match.
    """
    if expected_regex is not None:
        # should fail with a matching message
        with pytest.raises(ValueError, match=expected_regex):
            validate_google_style_docstring(fn)
        return

    # should pass
    validate_google_style_docstring(fn)


# ---------- passing cases ----------
# Fixture: every parameter is documented under Args and a Returns section is
# present. Its signature and docstring are the test input, so keep them as-is.
def good_function(file_requests: list, close_all_others: bool = False) -> str:
    """Open files.

    Args:
        file_requests (list): Requests.
        close_all_others (bool): Flag.

    Returns:
        str: Status.
    """
    return "ok"


# Fixture: same as good_function but deliberately without a Returns section;
# the parametrized test expects validation to pass for it.
def good_function_no_return(file_requests: list, close_all_others: bool = False) -> str:
    """Open files.

    Args:
        file_requests (list): Requests.
        close_all_others (bool): Flag.
    """
    return "ok"


# Fixture: `agent_state` is deliberately left out of the Args section; the
# parametrized test expects validation to pass anyway.
def agent_state_ok(agent_state, value: int) -> str:
    """Ignores agent_state param.

    Args:
        value (int): Some value.

    Returns:
        str: Status.
    """
    return "ok"


# Fixture: `Dummy.method` is passed to the validator straight from the class;
# `self` has no Args entry and the parametrized test expects validation to pass.
class Dummy:
    def method(self, bar: int) -> str:  # keeps an explicit self
        """Bound-method example.

        Args:
            bar (int): Number.

        Returns:
            str: Status.
        """
        return "ok"


# ---------- failing cases ----------
# Fixture: deliberately has no docstring; the parametrized test expects a
# ValueError matching "has no docstring". Do not add a docstring here.
def no_doc(x: int) -> str:
    return "fail"


# Fixture: the docstring deliberately lacks an Args section; the parametrized
# test expects a ValueError matching "must have 'Args:' section".
def no_args(x: int) -> str:
    """Missing Args.

    Returns:
        str: Status.
    """
    return "fail"


# Fixture: `y` is deliberately undocumented; the parametrized test expects a
# ValueError matching "parameter 'y' not documented".
def missing_param_doc(x: int, y: int) -> str:
    """Only one param documented.

    Args:
        x (int): X.

    Returns:
        str: Status.
    """
    return "fail"


# ---------- parametrized test ----------
@pytest.mark.parametrize(
    "fn, regex",
    [
        (good_function, None),
        (agent_state_ok, None),
        (Dummy.method, None),  # unbound method keeps `self`
        (good_function_no_return, None),
        (no_doc, "has no docstring"),
        (no_args, "must have 'Args:' section"),
        (missing_param_doc, "parameter 'y' not documented"),
    ],
)
def test_google_style_docstring_validation(fn, regex):
    """Validate each fixture's docstring; ``regex`` is ``None`` for fixtures that must pass."""
    _check(fn, expected_regex=regex)
